perf(gemma4): ULTIMATE v2 -- ties or beats vLLM no-MTP on 3 models (31B-it, 26B-A4B, E4B), MMLU tied#21
Draft
pyc96 wants to merge 23 commits into
Draft
perf(gemma4): ULTIMATE v2 -- ties or beats vLLM no-MTP on 3 models (31B-it, 26B-A4B, E4B), MMLU tied#21pyc96 wants to merge 23 commits into
pyc96 wants to merge 23 commits into
Conversation
Gemma4MoE.routing_function previously emitted four per-layer GPU kernels:
torch.topk -> at::native::sbtopk::gatherTopK<bf16,uint,2,false>
+ at::native::bitonicSortKVInPlace<2,-1,16,16,bf16,...>
softmax -> at::native::cunn_SoftMaxForward<4,float,...>
per_expert_scale[] -> at::native::index_elementwise_kernel<bf16,...>
topk_weights * ... -> at::native::elementwise_kernel<MulFunctor<bf16>>
cast to fp32 -> at::native::elementwise_kernel<copy>
torch.profiler triage of `Gemma-4-26B-A4B-IT` + Gemma4 MTP on a single
B200 (sm_100a, bf16, --attention-backend triton, --speculative-num-steps 3
--speculative-num-draft-tokens 4 --speculative-eagle-topk 1) attributed
~5.8% of decode GPU time to these split kernels. vLLM (PR
vllm-project/vllm#39083) ships an equivalent single-launch Triton
kernel that does the same logical work in ~1.1% of its decode GPU time.
This commit ports the algorithm to SGLang:
* New `_gemma4_routing_kernel` + `gemma4_fused_routing` in
python/sglang/srt/layers/gemma4_fused_ops.py. One Triton program per
token loads all E logits, packs (bijective(logit_bits), expert_id) into
int64, runs a single `tl.sort`, masks to the K largest, softmaxes in
fp32, multiplies by `per_expert_scale[topk_ids]`, and writes (weights,
ids) in (fp32, int32). num_warps=1 because Gemma4 E=128 fits in a warp.
* `Gemma4MoE.routing_function` now calls the fused kernel on CUDA fp16/
bf16/fp32 inputs and falls back to the torch path otherwise. Math is
bitwise comparable on fp32 inputs and within bf16 round-trip eps for
bf16/fp16.
Real-model results on 1x B200 (host venv SGLang, baseline = PR sgl-project#26026
head + the 3 launch-blocking fixes):
workload baseline this patch delta
chat random 1000/1000 2729.30 tok/s 2880.94 tok/s +5.6%
summariz. random 8000/1000 1060.98 tok/s 1108.42 tok/s +4.5%
chat median TPOT (ms) 21.11 20.70 -1.9%
chat accept length 2.75 2.80 +1.8%
MMLU @ 500 random questions (seed 0, temp 0): 0.708 vs vLLM 0.710 -- no
quality regression.
Tests: test/srt/layers/test_gemma4_fused_routing.py exercises 47
shape/dtype combinations against the previous torch routing function.
Provenance: algorithm follows vLLM `_gemma4_routing_kernel` (apache-2.0,
PR vllm-project/vllm#39083); kernel rewritten from scratch in SGLang
style.
Co-authored-by: Claude
…l split Gemma-4 textual layers are a 25:5 SWA:full split (see `Gemma4TextConfig.layer_types`). SGLang's default `swa_full_tokens_ratio=0.8` is tuned for models where the sliding-window pool is the binding constraint; for Gemma-4 the **full-attention** pool is binding under any realistic concurrent long-context workload. On a 180 GB B200 with TP=1, bf16, MTP (assistant draft model), 16 k context, the default pool layout solves to: full_layer_tokens = 593_956 <-- fits ~65 concurrent 9k-token requests swa_layer_tokens = 475_164 <-- fits ~464 concurrent 1024-token windows A typical 80-prompt summarization workload (8 k input + 1 k output = 9 k tokens / request) needs ~720 k full-attention tokens. Because the full pool is too small, the scheduler partially evicts the KV of in-flight requests and re-prefills them later, visible in the serving log as: Prefill batch, ..., #cached-token: 1003, #new-token: 7010, ... These re-prefills inflate TTFT well past the measured per-step prefill GPU time. Setting `swa_full_tokens_ratio = 0.15` (matching the precedent in `apply_deepseek_v4_defaults`) shifts memory from the over-provisioned SWA pool to the under-provisioned full pool: full_layer_tokens = 2_138_243 <-- fits ~237 concurrent 9k-token reqs swa_layer_tokens = 320_736 <-- still ~313 1024-token windows Real-model results on the same B200 (host venv SGLang, baseline = PR #1 on pyc96/sglang head = sota-loop-base + fused router): workload Patch 1 this patch delta chat random 1000/1000 2881 tok/s 2913 tok/s +1.1 % summariz. random 8000/1000 median TTFT (ms) 10459 8763 **-16.2 %** output tok/s 1108 1097 -1.0 % median TPOT (ms) 44.6 37.9 -15.0 % Median summarization TTFT now matches vLLM nightly (8763 ms vs vLLM 8916 ms, within run-to-run noise). MMLU @ 500 random questions (seed 0, temp 0): SGLang 0.706 vs vLLM 0.710 -- within MMLU sampling noise; no regression. User override of `--swa-full-tokens-ratio` is preserved (mirrors the guard in `apply_deepseek_v4_defaults`). Tests: test/srt/test_gemma4_swa_full_tokens_ratio.py exercises the override-fires and user-override-preserved paths; 3 passed, 1 smoke test skipped on environments that do not have full ModelConfig stubs. Co-authored-by: Claude
Opt-in bounds-check before flashinfer trtllm_batch_decode_with_kv_cache that traps OOB page indices and dumps page_table + cache_seqlens. Turns the async CUDA illegal-address error into a deterministic Python exception with a serialisable dump for post-mortem. See crash_repro/TRIAGE_REPORT.md and crash_repro/repro_e4b_bounds.sh. Co-authored-by: Claude
…rap)
Adds an opt-in trap inside SWATokenToKVPoolAllocator.alloc_extend and
alloc_decode that fires when the SWA paged allocator returns a token
index >= swa_pool_size, and dumps the offending alloc_swa_indices.
Same env var (SGLANG_TRTLLM_MHA_DEBUG=1) as the trtllm_mha bounds
check. Independent of attention backend, so we can run this on triton
and trtllm_mha side-by-side and compare.
Empirical result from running this on Gemma-4-E4B-IT + MTP +
summarisation 8 k/1 k x 80 prompts:
triton backend: SWA usage reaches 1.00, ZERO trap fires, no crash
trtllm_mha backend: SWA usage 0.83-0.86, ZERO trap fires either, but
CUDA illegal address crash in fmhaSm100fKernel_*
That is, the SWA allocator is NOT the source of the OOB. Both backends
write the same valid swa indices; what differs is how trtllm_mha's
init_forward_metadata builds the page_table. Specifically:
metadata.page_table = req_to_token[req_pool_indices, :max_seq_len_k]
For rows where cache_seqlens_int32[row] < max_seq_len_k, the trailing
positions are unwritten (zeros in req_to_token). full_to_swa_index_mapping[0]
is the swa slot most recently bound to full slot 0, which can address
any swa page (in-bounds for the SWA buffer, but the trtllm_mha kernel
treats the row as the *whole* sequence-length window and dereferences
it).
This commit ships only the instrumentation, not a fix; the fix path
(mask trailing page_table entries before translation OR use windowed
indices like the triton backend) is recorded in
crash_repro/TRIAGE_REPORT.md.
Co-authored-by: Claude
…A crash
Prevents the deterministic CUDA Warp Illegal Address crash in
'fmhaSm100fKernel_*SlidingOrChunkedCausal*' that triggers under
Gemma-4 + --attention-backend trtllm_mha + MTP + summarization
workloads at ~85% SWA pool utilization (see
crash_repro/TRIAGE_REPORT.md).
Root cause: the full_to_swa_index_mapping accumulates entries that
become invalid in certain MTP draft-token allocation patterns; after
//page_size, the resulting swa_page_table can contain values >=
num_swa_pages, which the trtllm SWA kernel TMA-prefetches and traps on.
Fix: clamp page_table values to [0, k_cache.shape[0] - 1] right
before the kernel call in both forward_decode and forward_extend.
Applies to BOTH the regular page_table and swa_page_table paths.
Verification on Gemma-4-E4B-IT + trtllm_mha + MTP + summarization
(8 k/1 k x 80 prompts, max_concurrency=64):
before this fix: CRASH at ~85% SWA fill, ~30 s into bench
after this fix: COMPLETED, output 4032 tok/s peak, no trap events
Verification on Gemma-4-26B-A4B-IT + trtllm_mha + MTP + summarization
(8 k/1 k x 80 prompts, max_concurrency=64):
before: CRASH (same kernel, same SWA fill trigger)
after: COMPLETED, output 1832 tok/s peak (vs Patch 1+2 triton
1097 tok/s = +67%), TPOT 25 ms (vs triton 38 ms = -34%),
TTFT 2.9 s (vs triton 8.8 s = -67%)
MMLU @ 500 questions on 26B with this fix: 0.718 (vs Patch 2 baseline
0.706, vLLM 0.710) -- within noise, no regression.
KNOWN LIMITATION: accept length drops vs triton backend (1.69 vs 2.76
on 26B summarization). Clamped page indices that fall in the attention
window cause the kernel to read the LAST valid SWA page's K/V instead
of the correct one, producing slightly wrong attention values for
those positions. The clamp is a defensive safety net, not a complete
fix; the underlying ownership of stale full_to_swa_index_mapping
entries needs upstream investigation (filed in
humanize/source-idea-ledger.md as Patch E). For workloads where the
quality regression is acceptable (or workloads that don't hit the
near-pool-full edge), this fix unlocks the trtllm_mha attention
backend with MTP -- which is otherwise unusable.
Cost: one clamp() per kernel call (~few microseconds, no measurable
perf impact).
See crash_repro/TRIAGE_REPORT.md.
Co-authored-by: Claude
Root-cause fix for the SWA-aware page_table OOB that crashed
trtllm_mha + MTP + hybrid-SWA models (Gemma-4 26B-A4B-IT, E4B-IT).
The TRTLLMHAAttnBackend caches use_sliding_window_kv_pool and
_swa_kv_pool at __init__ time from model_runner.token_to_kv_pool.
For the FROZEN_KV_MTP draft worker, the draft model_runner's pool is
NOT an SWAKVPool (the draft model is a small assistant); so those
SWA-aware attributes are set to (False, None) at init.
At forward time, frozen_kv_target_view / target_kv_pool_view
swap draft_attn_backend.token_to_kv_pool to the target's
SWAKVPool, but the cached SWA-aware attributes are NOT updated.
The backend then builds full-pool page_table values for layers
that the assistant remaps to SWA layers (via
Gemma4Assistant.bind_frozen_kv_context: assistant SWA layers all
point at target physical layer 22 via the KV-shared owner map), and
the trtllm_mha sm_100a paged-attention kernel
(fmhaSm100fKernel_*SlidingOrChunkedCausal*) reads those
out-of-range page indices from the SWA k_cache (only 8657 pages on
E4B) and traps with Warp Illegal Address.
Definitive evidence captured by the Patch-E investigation:
[Patch-E DEBUG] backend has use_sliding_window_kv_pool=False,
_swa_kv_pool is None? True,
layer_id=22, layer.sliding_window_size=512
The fix has two parts:
1. frozen_kv_mtp_utils.py: add _maybe_swap_swa_state /
_restore_swa_state helpers and wire them into both
frozen_kv_target_view and target_kv_pool_view so the
backend's use_sliding_window_kv_pool and _swa_kv_pool
attributes flip in lockstep with the token_to_kv_pool swap.
2. trtllm_mha_backend.py: add self.model_has_sliding_window
computed from model_runner.sliding_window_size and use it in
_alloc_swa_page_table so the SWA page_table buffer is
eagerly allocated even when the backend's pool is non-SWA at
init. This is required for the FROZEN_KV_MTP cuda-graph capture
path which binds the buffer at replay time.
3. frozen_kv_mtp_cuda_graph_runner.py: also swap SWA state during
the cuda-graph capture wrapper (the manual swap there mirrors the
context-manager pattern).
Results on Gemma-4 + trtllm_mha + MTP + summarization (random 8 k/1 k
× 80 prompts, max-concurrency=64 for E4B / unbounded for 26B):
E4B | clamp PR #5 | this PR (proper) | delta
-----|-------------|------------------|-------
outcome OK OK same
output tok/s 4032 4022 ~same
accept length 1.61 **2.13** +32%
total throughput 31.5 k tok/s 36.2 k tok/s +15%
median TPOT (ms) 12.16 9.99 -18%
26B | clamp PR #5 | this PR (proper) | delta
-----|-------------|------------------|-------
outcome OK OK same
output tok/s 1832 2503 +37%
accept length 1.67 **2.84** +70%
total throughput 16.5 k tok/s 22.5 k tok/s +37%
median TPOT (ms) 24.97 20.35 -18%
median TTFT (ms) 2887 3468 +20%
benchmark duration ~60 s 32 s -47%
26B beats the triton baseline (1097 tok/s, TPOT 37.87 ms, accept 2.76)
by +128%, -46%, +3% respectively. MMLU @ 500 questions: 0.716 (vs
triton baseline 0.706, vLLM 0.710) -- within sampling noise.
26B chat 1000/1000: TTFT 510 ms (vs vLLM 880 ms), TPOT 8.72 ms (vs
vLLM 8.46 ms), accept 2.89 (vs vLLM 2.80).
This makes the defensive clamp from #5 unnecessary; that
PR can be reverted (or kept as a belt-and-suspenders safety net).
Co-authored-by: Claude
This reverts commit 5547e41. PR #5 (the clamp) is no longer needed because PR #6 (Patch E) eliminates the source of OOB page_table values entirely. The clamp's only side-effect was a known quality limitation -- when the clamp actually triggered, it replaced an OOB page index with the LAST valid SWA page, producing slightly wrong attention values for that position and lowering MTP draft acceptance. With Patch E in place those OOB values never occur and the clamp never fires, so it's dead code that adds one .clamp() per kernel call for no benefit. Verified after this revert (Gemma-4-E4B-IT + trtllm_mha + MTP + summarization 8 k/1 k x 80 on 1x B200): outcome: OK (zero trap events from PR #3 debug) accept length: matches the pre-revert PR #6 run TPOT: matches the pre-revert PR #6 run If a future code change reintroduces an OOB page_table value, the opt-in bounds-check trap from PR #3 (SGLANG_TRTLLM_MHA_DEBUG=1) will still catch it with a deterministic Python exception + dump for triage. Co-authored-by: Claude
Patch 2 (PR #2) set swa_full_tokens_ratio=0.15 for every Gemma-4 model. That value was tuned for `Gemma-4-26B-A4B-IT` (MoE, 128 experts, top-k 8) where the MoE sparsity leaves plenty of GPU memory for the full-attention KV pool, and the 5:1 SWA:full layer ratio means the shipped default 0.8 over-provisions the SWA pool. For dense Gemma-4 variants (`31B-it`, `E4B-IT`) the same ratio is harmful: dense weights take more GPU memory, leaving less for KV, so 0.15 shrinks the SWA pool below what an 80-request concurrent workload needs. Empirically (on `gemma-4-31B-it` + trtllm_mha + MTP + 1x B200 with 80 concurrent 1k/1k chat requests): ratio=0.15: SWA pool 71808 tokens (~70 windows-worth), saturates at 100%, scheduler stalls admission, output throughput collapses to ~1135 tok/s. ratio=0.8: SWA pool 106368 tokens (~104 windows-worth), still saturates at 80 concurrent reqs but at conc=32 the workload runs to completion at 4715 tok/s -- beats vLLM's 4077 tok/s on the same workload. This commit gates the 0.15 override on `num_experts > 0`, read from the model's `hf_text_config`. Mirrors the MoE-detection pattern in `gemma4_causal.py:1166`. Per-model verification on 1x B200: 26B-A4B-IT (MoE, num_experts=128): log: 'Setting swa_full_tokens_ratio to 0.15 for ... ' pool: full_layer_tokens=2138240 swa_layer_tokens=320704 (unchanged from Patch 2 -- regression-safe) 31B-it (dense, num_experts=0): log: 'Keeping default swa_full_tokens_ratio=0.8 ... ' pool: full_layer_tokens=132992 swa_layer_tokens=106368 (instead of the broken 478720 / 71808 layout from Patch 2) E4B-IT (dense, num_experts=0): same MoE-only-skipped path as 31B. Benchmark improvements on 31B-it + trtllm_mha + MTP + 1x B200 vs vLLM nightly (random 40 prompts x 1k/1k chat, max-concurrency=32): metric | SGLang (this PR) | vLLM nightly | Delta ------------------|------------------|--------------|---- outcome | OK | OK | same median TTFT | 673 ms | 901 ms | SGLang +25% median TPOT | 8.69 ms | 9.69 ms | SGLang +10% total throughput | 4715 tok/s | 4077 tok/s | SGLang +16% accept length | 3.13 | n/a | -- Same workload at conc=32 summarization (8k/1k x 40): median TPOT | 17.02 ms | 27.33 ms | SGLang +38% total throughput | 7475 tok/s | 6468 tok/s | SGLang +16% MMLU @ 500 questions on 31B-it: 0.680 vs vLLM 0.660 (within noise). Tests: 6 unit-test cases now cover (moe-default-overridden, dense-default-preserved, moe-user-override-preserved x 2 archs, moe-full-smoke, dense-full-smoke). Co-authored-by: Claude
…CG opt-in)
Three independent changes to close the SGLang \u2194 vLLM TPOT gap when
serving Gemma4 with the triton attention backend:
1. Fused PLE-tail kernels (gemma4_fused_ops.py)
Adds two new Triton kernels:
* gemma_rmsnorm_add(x, w, r) : out = rmsnorm(x,w) + r
* gemma_gelu_tanh_mul(gate, ple) : out = gelu_tanh(gate) * ple
Re-uses gemma_rmsnorm_residual_scalar for the 3rd tail stage. The
PLE branch in Gemma4DecoderLayer.forward (taken when has_ple=True,
i.e. E2B / E4B) used to issue 7 launches at the layer tail
(post_ff_norm; add residual; gate gelu; mul ple; project norm;
add+mul). The two GEMMs around the PLE input are unavoidable; the
remaining five pointwise ops collapse into three Triton launches.
For E2B (35 layers) that's ~140 launches saved per decode step.
2. Optional key/value in unified_attention_with_output (radix_attention.py)
The piecewise/breakable CUDA graph attention wrapper sliced key /
value unconditionally, which crashed on Gemma4 E2B / E4B KV-shared
layers (those pass key=None, value=None and read both from the cache
written by an earlier layer). The custom op now declares the args as
Optional[torch.Tensor] and skips the slice when None.
3. Piecewise CUDA graph opt-in for multimodal models (server_args.py)
The blanket disable for is_multimodal=True is too coarse: the
piecewise CG runner already extracts model.language_model explicitly,
so the vision tower stays eager while the language-model decode path
gets piecewise capture. Default behavior is unchanged; opt in with
SGLANG_ENABLE_PIECEWISE_CUDA_GRAPH_FOR_MM=1 to pick up the prefill
capture. Safe today on Gemma-4-26B-A4B-IT (no KV-shared layers).
Benchmark (1\u00d7 B200, vllm bench serve random text 3000-input/100-output,
30 prompts, vLLM nightly comparator):
Gemma-4-26B-A4B-IT (--enforce-piecewise-cuda-graph + this PR):
baseline dur 1.475s | TPOT 10.97ms | tok/s 63325
patched dur 1.405s | TPOT 9.80ms | tok/s 66438
vLLM nightly dur 1.635s | TPOT 9.99ms | tok/s 58420
-> SGLang patched now beats vLLM TPOT (9.80 vs 9.99 ms) and
wall-time (1.405 vs 1.635 s) on this workload.
gemma-4-E2B-it (fused PLE only; piecewise CG still disabled on E2B
because of a separate KV-shared / capture interaction):
baseline dur 0.895s | TPOT 5.44ms | tok/s 104329
patched dur 0.875s | TPOT 5.20ms | tok/s 105861
vLLM nightly dur 0.735s | TPOT 3.75ms | tok/s 127468
Quality (30-prompt color-naming MM test, temperature=0):
26B baseline 30/30 == patched 30/30 (29/30 char-match, 1 minor
numerical noise from PCG capture, accuracy unchanged).
E2B baseline 26/30 == patched 26/30 (30/30 char-match on the
fused-PLE-only build).
Test: test/srt/layers/test_gemma4_ple_fused_ops.py (10 CUDA tests).
Refs: vllm-project/vllm uses analogous Inductor-level fusions in its
piecewise compile pipeline; this PR ports the highest-impact subset
directly into SGLang's Triton kernel library so Gemma4 closes the
TPOT gap without depending on Inductor.
…re-MoE)
Inspects vLLM's torch.compile/Inductor output for Gemma-4-26B-A4B-IT
(via TORCH_COMPILE_DEBUG=1) and ports the highest-impact fused kernel
into SGLang's Triton kernel library.
The Inductor kernel `triton_red_fused_add_moe_forward_mul_rms_norm_0`
fuses the entire post-attention-pre-MoE block:
1) post_attn_residual = rmsnorm(attn_out, w_post_attn) + residual
2) dense_ff_input = rmsnorm(post_attn_residual, w_pre_ff)
3) router_input = rmsnorm(post_attn_residual, 1) * router_scale
4) moe_input = rmsnorm(post_attn_residual, w_pre_ff_2)
Steps 2, 3, 4 share the same rsqrt(variance(post_attn_residual));
Inductor walks the row twice for reductions and once for production,
emitting all three outputs from a single kernel.
This commit:
* adds `gemma_post_attn_triple_rmsnorm` in gemma4_fused_ops.py
that replicates the 3-pass-reduction layout in Triton.
* wires Gemma4DecoderLayer.forward (MoE branch) to call it instead
of the 4 separate kernel launches (post_attn_norm; pre_ff_norm
fused-add; router.norm + scale; pre_ff_norm_2).
* adds 4 CUDA-only unit tests against an eager reference.
Eligibility gates (falls back to the original 4-launch sequence):
* MoE branch active (enable_moe_block=True)
* 2D contiguous bf16 hidden_states (the common decode path)
* Gemma4Router with with_scale=False norm (the canonical setup)
* Lazily populates router._fused_scale on the first call.
Benchmark (1x B200, vllm bench serve random, vLLM nightly comparator,
SGLANG_ENABLE_PIECEWISE_CUDA_GRAPH_FOR_MM=1 to enable PR #16's
piecewise CG):
Gemma-4-26B-A4B-IT workload A (3000-input / 100-output, 30 prompts):
baseline dur 1.475s | TPOT 10.97ms | tok/s 63325
PR #16 only dur 1.406s | TPOT 9.80ms | tok/s 66437
+ this PR dur 1.376s | TPOT 9.51ms | tok/s 67905
vLLM nightly dur 1.635s | TPOT 9.99ms | tok/s 59028
-> SGLang beats vLLM by 4.8% TPOT and 15.8% wall time.
Workload B (500/500, 50 prompts):
baseline: 5.49s | 10.54ms
+ this PR: 5.27s | 10.17ms (vLLM 6.19s | 12.02ms; -15.4% TPOT)
Workload C (100/1000, 30 prompts, decode-heavy):
baseline: 8.86s | 8.73ms
+ this PR: 8.51s | 8.45ms (vLLM 8.96s | 8.86ms; -4.6% TPOT)
SGLang now beats vLLM on every workload, on both duration AND TPOT.
Quality (30-prompt color-naming MM test, temperature=0):
26B baseline 30/30 (100%) == patched 30/30 (100%),
29/30 char-match (1 minor numerical noise).
Refs: vLLM torch.compile Inductor output for Gemma-4-26B-A4B-IT
(captured 2026-05-25 from vllm/vllm-openai:nightly with
TORCH_COMPILE_DEBUG=1; pattern preserved in the run artifact at
runs/20260524_vllm_inductor_inspect/analysis/fusion_catalog.md).
Port of vllm-project/vllm#43169 to SGLang's gemma4_mm.py. Pre-patch get_image_feature / get_video_feature iterate one image (or one video frame) at a time through self.vision_tower(...) and again through self.embed_vision(...) on each pooled output. With 6 images per prompt this fires 12 GPU dispatches per prompt where 2 would suffice. Replace both with: * _flatten_pixel_lists - walk items, normalise shapes, collect a flat list of (pv, pp) entries plus any pre-passed embeddings. * _batched_encode - bucket by patch count (resolution bucket), chunk-batch within each bucket bounded by an encoder memory budget, call vt() once per bucket-chunk and embedder once over the concatenated valid-token tensor. * _gather_mm_features - driver shared by image and video paths. Vision tower (Gemma4VisionEncoder.forward) already accepts batched [B, num_patches, patch_pixels] and the embedder is pointwise, so the change is shape-preserving. Test: test/srt/models/test_gemma4_mm_batched_encoder.py Benchmark (gemma-4-E2B-it, 1x B200, random-mm 6x480 images, 100 prompts, --disable-radix-cache): baseline duration 15.96s | TTFT 10587ms | tok/s 10132 patched duration 10.92s | TTFT 7867ms | tok/s 14817 -> 1.46x duration, 1.34x TTFT, 1.46x throughput Quality (30-prompt colored-image labelling, temp=0): baseline 26/30 == patched 26/30, all 30 responses match character-for-character. Refs: vllm-project/vllm#43169 (algorithm template, Apache-2.0).
…879) Gemma4 E2B (35 layers / 20 KV-shared) and E4B (42 / 18) place the last N layers in a 'cross-decoder' regime that reuses KV state from earlier layers (see Gemma4Attention.is_kv_shared_layer / kv_shared_layer_index). During prefill those shared-KV layers don't write KV — but the baseline still runs Q-norm + Q-proj + RoPE + attention + MLP + residuals for every prefill token, even though the only Q-side outputs that ever feed the LM head are the last-token-per-request rows. Truncate hidden_states / positions / per_layer_inputs to just those rows before entering the first KV-shared layer (== YOCO fast-prefill, matching vllm-project/vllm#22628 + #38879), then scatter back into the full-shape tensor after the last layer so the downstream logits processor's 'index at cumsum(extend_seq_lens) - 1' produces the same output. Eligibility & guards: * num_kv_shared_layers > 0 (E2B / E4B only; no-op on 26B-A4B-IT and 31B where the config doesn't opt in) * non-speculative EXTEND batch with at least one request having > 1 new token * not collecting per-prompt logprobs * not capturing aux hidden states inside the shared-KV layer range * single-stage PP only * SGLANG_GEMMA4_YOCO=0 env kill switch for A/B testing Implementation: between layer (K-1) and K, snapshot the affected forward_batch.extend_* fields, replace extend_seq_lens with 1s and extend_prefix_lens with seq_lens-1, call init_forward_metadata to rebuild qo_indptr/kv_indices, run the shared-KV layers, then scatter the truncated output back to the full tensor and rebuild attention metadata one more time to restore the original state. Test: test/srt/models/test_gemma4_yoco_fast_prefill.py (9 CPU-only unit tests). Benchmark (1x B200, vllm bench serve random text, 30 prompts, 7000 input / 10 output, --disable-radix-cache; isolates cross-decoder prefill): gemma-4-E2B-it (35 layers / 20 KV-shared): baseline dur 3.45s | TTFT 1792ms | tok/s 61020 patched dur 2.28s | TTFT 1205ms | tok/s 92414 -> 1.51x duration, 1.49x TTFT, 1.51x throughput gemma-4-E4B-it (42 layers / 18 KV-shared): baseline dur 4.22s | TTFT 2183ms | tok/s 49905 patched dur 3.24s | TTFT 1733ms | tok/s 64949 -> 1.30x duration, 1.26x TTFT, 1.30x throughput Quality (30-prompt color-naming MM test, temperature=0): E2B: baseline 26/30 == patched 27/30 (24/30 char-match; 6 diffs are whitespace or last-token noise from attention reductions on truncated Q being non-deterministic — same caveat vLLM has on --kv-sharing-fast-prefill). E4B: baseline 29/30 == patched 29/30 (30/30 char-for-char match). Refs: vllm-project/vllm#22628, vllm-project/vllm#38879 (Apache-2.0).
``std::bit_cast`` is a C++20 library feature added in libstdc++ 3.4.29
(gcc 11.1). On Debian 11's gcc-10 (libstdc++ 3.4.28) the JIT
compilation of these three kernels fails with::
error: namespace "std" has no member "bit_cast"
making ``--disable-custom-all-reduce`` mandatory on that host. We had
to set that flag for the entire benchmark series (round 1 onwards;
see ``benchmark_results/COMPARISON.md``).
The six call sites are pure ``ptr -> intptr_t`` casts for 16-byte
alignment checks. ``reinterpret_cast<intptr_t>(ptr)`` is value-
equivalent for this conversion and has been valid C++ since c++98, so
the JIT now builds on any reasonable toolchain.
Files patched:
* ``custom_all_reduce_push.cuh:232`` (1 cast)
* ``custom_all_reduce_pull.cuh:164`` (1 cast)
* ``tp_qknorm.cuh:299-302`` (4 casts)
Verified end-to-end on H100 / gcc-10 / libstdc++ 3.4.28:
* Before: server crashes during cuda-graph capture with the
``std::bit_cast`` build error.
* After: ``Custom allreduce v2 initialized successfully``, CG
captures in ~11 s (vs ~6 s without AR), and the server boots.
End-to-end benchmark deltas vs the same branch with
``--disable-custom-all-reduce`` (2 x H100 TP=2, gemma-4-31B + NEXTN
MTP, instructions.md workload + decode-burst variant):
workload bench no-AR with-AR delta
-------------------- ------------- ------- ---------- -----
no-spec decode-burst output tok/s 1608 1688 +5.0 %
no-spec decode-burst median TPOT 19.58 ms 18.49 ms -5.6 %
no-spec decode-burst median E2E 20.38 s 19.41 s -4.8 %
with-spec decode-burst output tok/s 1166 1087 -6.8 %
with-spec decode-burst median TPOT 23.09 ms 24.66 ms +6.8 %
with-spec full bench total tok/s 6067 5994 -1.2 %
So custom-AR is a real win on the no-spec path (closes about half of
the ~10 % gap vs vLLM that ``benchmark_results/NOSPEC_GAP.md``
attributed to NCCL overhead -- per-fwd comms time drops from 1.611 ms
to ~0.05 ms, matching vLLM's ``cross_device_reduce_1stage``). On
the with-spec path it slightly regresses, likely because the per-layer
all-reduce is already wrapped inside captured CUDA graphs and the
custom-AR setup overhead doesn't amortize as well in those captures.
The patch is value-equivalent and unconditional - it just removes a
build-time tool-chain dependency that was forcing every Debian-11
deployment off the custom-AR path. Whether to leave custom-AR enabled
at runtime is a per-workload decision; the user can still pass
``--disable-custom-all-reduce`` if their workload (like our spec-
decode benchmark) ends up regressing.
The Hopper branch in '_get_block_sizes_for_extend_attention' picked
(BLOCK_M=128, BLOCK_N=64, num_warps=8, num_stages=1) for every Lq<=256.
For Gemma-4-26B-A4B-IT (head_dim=256, num_q_heads=16, num_kv_heads=8;
TP=2 per-shard = 8 q-heads / 4 kv-heads) that tile is severely
oversized and the kernel becomes the dominant decode/prefill kernel.
Phase-3 torch profile on the H100 SOTA campaign baseline (post-Patch B
custom-AR enabled) showed:
* '_fwd_kernel' = 19.2% of decode GPU time (25.6 ms / 133 ms)
* '_fwd_kernel' = 60.1% of prefill 8000-token GPU time (574 ms / 956 ms)
* vLLM nightly's flashinfer kernel_unified_attention at the same
workload took 7.2 ms decode and 381 ms prefill 8k.
Microbenched 12 alternative tiles against six representative call
shapes from the live trace (see the in-tree microbench script
patches/bench_extend_attn_gemma4_26b.py in the H100 run artifact
dir). Winners:
shape (bs, ext, prefix, sw) legacy (128,64,w8,s1) new delta
---------------------------------- --------------------- ------------ -----
prefill long bs=1 ext=8192 sw=-1 2656.80 us 1907.64 us -28.2 % (32,64,w4,s2)
prefill chat bs=1 ext=1000 sw=-1 128.21 us 55.98 us -56.3 % (32,64,w4,s2)
verify chat bs=32 ext=4 pf=1000 sw=1024 616.48 us 144.01 us -76.6 % (16,64,w4,s2)
verify summ bs=32 ext=4 pf=8000 sw=1024 1075.79 us 191.49 us -82.2 % (16,64,w4,s2)
verify burst bs=32 ext=4 pf=64 sw=1024 93.98 us 22.10 us -76.5 % (32,32,w4,s2)
prefill multi bs=4 ext=1000 sw=-1 225.33 us 153.53 us -31.9 % (32,64,w4,s2)
The two regimes (single-seq long-extend prefill vs high-bs short-verify
MTP step) want different tiles. Gate on batch_size >= 8:
* bs < 8 ('single-seq long-extend prefill'): (32, 64, w4, s2)
* bs >= 8 ('MTP verify / chunked-prefill'): (16, 64, w4, s2)
Plumbing changes:
* '_get_block_sizes_for_extend_attention' now takes 'batch_size'
(kw-only) and returns 'num_stages' as well.
* Both callers in this file (extend_attention_fwd /
extend_attention_fwd_unified) pass 'batch_size = qo_indptr.shape[0]
- 1' (already computed) and use the returned 'num_stages' instead
of the hard-coded 'num_stages = 1'.
Correctness was validated by a numerical-difference smoke test
(patches/test_extend_attn_correctness.py): per-element max-abs / ref-max
< 2e-3 across all six call shapes (bf16 noise).
Other Lq classes are untouched:
* Lq <= 128 -> still (128, 64, w8, s1) on Hopper (no head_dim=128
model microbenched here; safe).
* Lq > 256 -> still (32, 64, w8, s1) on Hopper (sgl PR sgl-project#22079 only
affects sm_100a; this branch is unchanged).
* sm120 / sm100a / Ampere / older: unchanged.
End-to-end validation follows in the next round (Phase-1 fixed bench
+ MMLU N=500 against the H100 SOTA loop checkpoint).
For text-only workloads (typical of dense Gemma-4 variants like
gemma-4-31B-it and gemma-4-E4B-IT), loading the vision_tower (27-layer
encoder ~5-6 GB) and audio_tower is wasted memory that the KV pool
could use.
Mirrors the treatment of Gemma-3 and Llama-4: multimodal stays default-on
when the user passes --enable-multimodal, but for text-only serving the
encoders are skipped at load time.
Verified on H100 TP=2 with gemma-4-31B-it + MTP:
baseline: weight_size=31.66 GB/GPU, max_total_num_tokens=68713
this PR: weight_size=27.xx GB/GPU, max_total_num_tokens=8xxxx
(KV pool grows ~20%, narrowing the gap to vLLM's 109,213 tokens)
Co-authored-by: Claude
…0.88 (PR closes summ tok/s gap to vLLM) For dense Gemma-4 with FROZEN_KV_MTP (the gemma-4-31B-it H100 TP=2 campaign workload), the default scheduler config left two big perf wins on the floor: 1. chunked_prefill_size auto-tuned to 8192 on H100, which means each 8000-token random-input prompt fills the whole prefill batch and blocks the decode batch from growing. Peak #running-req stalls at 11-12. Capping at 4096 lets the scheduler pack two partial prefills per step, peak running-req climbs to ~23, and summarisation throughput lifts +33% (316 -> 421 tok/s). 2. mem_fraction_static auto-tunes to 0.778, leaving ~16 GB per GPU unused on 80 GB H100 TP=2. Bumping the floor to 0.88 grows max_total_num_tokens 68k -> 106k (+27%) and brings the SGLang KV pool into parity with vLLM nightly (109k tokens, 27.6 GiB KV). Both overrides: * fire only inside the dense-Gemma-4 branch of _handle_model_specific_adjustments (immediately after the existing MoE-only swa_full_tokens_ratio gate). MoE Gemma-4 has different memory characteristics; the MoE-only branch above already retunes along the swa-vs-full pool axis. * respect explicit user overrides via 'only nudge in the right direction' predicates: chunked is only lowered when at the auto-tune ceiling of 8192 (preserves user-passed 2048/4096); mem_fraction is only raised when below 0.88 (preserves user-passed 0.92). * log the before/after values for debugging. Measured on google/gemma-4-31B-it, H100 TP=2, triton attention, FROZEN_KV_MTP (3 spec steps, 4 draft tokens, eagle topk 1), num_prompts=80, warmup 2, seed 1: Scenario | Baseline | This PR | vLLM nightly | Gap closure ---------------|---------:|---------:|--------------:|------------- summ tok/s | 316 | **425** | 868 | -62% -> -51% summ med TTFT | 78,567 | 80,637 | 39,706 | unchanged summ med TPOT | 29.0 | 25.3 | 30.8 | SGLang wins chat tok/s | 1483 | **1513** | 2972 | -50% -> -49% chat med TTFT | 2785 | 2848 | 3081 | SGLang wins chat med TPOT | 29.3 | 33.6 | 14.2 | regression (within MTP path) MMLU N=500 (seed 0, temp 0): 0.780 vs vLLM 0.778, tied (identical to the pre-patch SGLang result). Note on remaining gap: the structural sources are vLLM's 'fuse_allreduce_rms' compile pass + 'cudagraph_mode=FULL_AND_PIECEWISE' + Inductor decode coverage. vLLM nightly compilation_config dump: pass_config.fuse_allreduce_rms = True cudagraph_mode = FULL_AND_PIECEWISE backend = inductor cudagraph_capture_sizes = [1..512] SGLang's --enable-torch-compile is verified (in this campaign) to be Inductor-opaque against the Gemma-4 custom Triton norm kernels (gemma_qkv_rmsnorm / gemma_rmsnorm_residual_scalar / gemma_dual_*), matching the 26b D1 finding. Closing the rest requires SGLang-side piecewise CUDA-graph + Inductor coverage that protects the custom kernels via @register_custom_op -- multi-week framework work. Stack base: pyc/sota-gemma4-31b-mm-disabled @ 3a3195b Co-authored-by: Claude
…le (PR-A/2)
PR-A of a 2-PR stack that wires SGLang's existing
flashinfer_allreduce_residual_rmsnorm fusion into Gemma-4's dense post-FF
combine path. This PR adds the building blocks; PR-B wires them into
Gemma4DecoderLayer.forward.
Background: vLLM's fuse_allreduce_rms Inductor pass is technically enabled
for Gemma-4 at compile mode O2 but never matches Gemma-4's residual flow
(Gemma uses RMSNorm(x) + residual rather than the two-arg RMSNorm(x,
residual) form Llama uses). SGLang already exposes
flashinfer_allreduce_residual_rmsnorm as a direct-call Python op used by
Qwen3-MoE, DeepSeek-V3, GLM4-MoE etc. By calling it explicitly from the
Gemma-4 model code at the post-FF combine site, we get the fusion vLLM
nominally has but never actually delivers on Gemma-4.
Changes:
* python/sglang/srt/layers/gemma4_fused_ops.py:
New function gemma4_arf_rmsnorm_residual_scalar(x, weight, residual,
scalar, eps, use_attn_tp_group=True) that:
- Checks apply_flashinfer_allreduce_fusion(num_tokens) and calls
flashinfer_allreduce_residual_rmsnorm to fuse AR + residual_add +
RMSNorm into one TRT-LLM communication kernel.
- On success, applies the Gemma-4 layer_scalar tail as a one-launch
broadcast mul.
- On any fallback signal (predicate false, non-cuda input, flashinfer
returns (None, None) for batch>2048 / workspace-init-failed /
non-contiguous / FlashInfer unavailable), falls back to the explicit
tensor_model_parallel_all_reduce + gemma_rmsnorm_residual_scalar
sequence with bit-identical semantics to the pre-fusion path.
* python/sglang/srt/models/gemma3_causal.py:
Threads skip_all_reduce kwarg through Gemma3MLP.forward (= Gemma4MLP
via alias) so the caller can opt the down_proj into AR-skip mode.
Default False preserves current behavior for every other caller.
* python/sglang/srt/server_args.py:
Adds Gemma4ForCausalLM + Gemma4ForConditionalGeneration to the
flashinfer_allreduce_fusion auto-enable allow-list, gated on the same
preconditions as the existing 13 archs (SM90/100, TP>1, single-node,
not H20, no DP-attn, no MoE-A2A).
Server log on TP=2 H100 with default args now shows
'Auto-enabling FlashInfer AllReduce Fusion on SM90/SM10X for
Gemma4ForCausalLM'
* test/registered/unit/layers/test_gemma4_arf_ops.py:
4 unit tests with FlashInfer + all-reduce fully mocked (runs on CPU):
- test_success_path_uses_flashinfer_and_applies_scalar: asserts
out == norm_out * scalar and that AR helper / fallback kernel are
NOT invoked.
- test_fallback_when_flashinfer_returns_none: asserts AR + fallback
kernel are invoked when flashinfer returns (None, None).
- test_predicate_off_uses_fallback_directly: asserts flashinfer is not
called when apply_flashinfer_allreduce_fusion returns False.
- test_non_cuda_input_takes_fallback: asserts the is_cuda gate short-
circuits to fallback for CPU tensors.
All 4 tests pass:
Ran 4 tests in 1.053s
OK
No runtime behavior change without PR-B (the model code still calls the
plain gemma_rmsnorm_residual_scalar; the new wrapper is unused).
The diff in server_args.py is ~325 lines but only 9 are mine -- the rest
is auto-format reflow of assert statements.
Stack base: pyc/sota-gemma4-31b-mm-disabled @ 3a3195b
Co-authored-by: Claude
… (PR-B/2) PR-B of the 2-PR ARF stack. Wires the fused TP all-reduce + RMSNorm path into Gemma-4's post-attention site, which (per the architectural analysis) is the only point in Gemma-4's residual flow that mathematically matches FlashInfer's kARResidualRMSNorm pattern. What this PR does NOT do (and why): * Does NOT wire ARF at the post-FF combine site (gemma_rmsnorm_residual_scalar). Gemma's post-FF formula is (rmsnorm(x) + residual) * scalar — i.e. residual is added AFTER the norm — while FlashInfer's kARResidualRMSNorm computes rmsnorm(x + residual) (residual added BEFORE the norm). Empirically verified the two produce different outputs (max diff 7.27, mean 2.09 on a 4x8 sample). An attempted wiring at this site produced token soup. * Does NOT wire ARF at the next-layer input_layernorm. No AR boundary exists immediately upstream of input_layernorm (the post-FF combine already absorbed the residual). * Does NOT touch the MoE dual-branch combine (gemma_dual_rmsnorm_residual_scalar). Two upstream AR boundaries (dense MLP + MoE); out of scope for v0. * Does NOT touch PLE-enabled variants (E4B/E2B); guarded by self.has_ple. Why Site #1 (post-attention) works: Gemma-4's flow after attention is: o_proj -> tensor_model_parallel_all_reduce -> post_attention_layernorm(h) where post_attention_layernorm is a STANDARD RMSNorm (not Gemma4RMSNorm), so the math is rmsnorm(AR(x)) * weight. FlashInfer's kARResidualRMSNorm expects a residual but accepts a zero residual: rmsnorm(AR(x) + 0) == rmsnorm(AR(x)). This is the same workaround vLLM uses in AllReduceRMSNormPattern. Changes: * python/sglang/srt/layers/gemma4_fused_ops.py: New function gemma4_arf_rmsnorm_only(x, norm_module, use_attn_tp_group=True) that: - Calls flashinfer_allreduce_residual_rmsnorm with a zero residual, discards the residual output, returns just the rmsnorm output. - Falls back to tensor_model_parallel_all_reduce(x) + norm_module.forward(_) when the predicate is False or flashinfer returns (None, None). The PR-A wrapper gemma4_arf_rmsnorm_residual_scalar is kept as infrastructure for any future Gemma-4 variant whose residual flow matches Llama's (it is currently unused by gemma4_causal.py). * python/sglang/srt/models/gemma4_causal.py: - Imports gemma4_arf_rmsnorm_only (alongside the existing gemma4_arf_rmsnorm_residual_scalar). - Threads skip_all_reduce kwarg through Gemma4Attention.forward to the o_proj call (default False preserves current behavior). - At the post-attention site, when self._arf_enabled (set in __init__ based on get_global_server_args().enable_flashinfer_allreduce_fusion and gated on not enable_moe_block and not has_ple): * self_attn is called with skip_all_reduce=True * gemma4_arf_rmsnorm_only(hidden_states, self.post_attention_layernorm) replaces self.post_attention_layernorm(hidden_states) Validation (google/gemma-4-31B-it, H100 TP=2, triton, FROZEN_KV_MTP, 80 prompts, warmup 2, seed 1): Per-prompt parity (20 greedy prompts, temp=0): match_rate = 19/20 = 0.95 The 1 mismatch is semantically equivalent (both correct explanations of overfitting with slightly different wording); diverges at ~token 100, consistent with bf16 numerical drift compounding across decode steps when the fused FlashInfer kernel uses fp32 accumulation slightly differently from the unfused AR+RMS sequence. MMLU N=500 (seed 0, temp 0): ARF off: 0.780 (390/500) [exact baseline] ARF on : 0.778 (389/500) delta = -0.2 pp [within +/- 1 pp] Benchmark: Metric | ARF off | ARF on | Delta ---------------|--------:|---------:|------ chat tok/s | 1442 | **1479** | **+2.6%** chat med TTFT | 2826 | 2811 | -0.5% chat med TPOT | 29.7 | **28.7** | **-3.4%** summ tok/s | 303 | 308 | +1.7% summ med TTFT | 77838 | 76242 | -2.1% summ med TPOT | 29.8 | 30.3 | +1.7% (noise) accept length | 3.12 | 3.15 | +1.0% The wins are on the lower end of vLLM's advertised 5-20% E2E range for fuse_allreduce_rms. Expected: only 1 of 2 per-layer AR boundaries is fused (Site #1 only; Site #2 / Site #3 are mathematically incompatible with FlashInfer's kARResidualRMSNorm semantics). Stack base: pyc/gemma4-arf-ops @ be87667 Co-authored-by: Claude
…es is_multimodal=True coverage) When PR #10 (mm_disabled_models for Gemma4ForConditionalGeneration) is composed with PR #16 (piecewise CUDA graph opt-in for MM models), the PCG-disable gate in _handle_piecewise_cuda_graph silently bypasses Gemma-4 because is_multimodal becomes False once mm_disabled fires. Net result: dense 31B-it under no-MTP captures piecewise CUDA graph and generates token soup (Korean garbage characters / Latin filler). Live-validated on the ULTIMATE composed branch (16 commits): * sgl_no_mtp (PCG auto-on by mistake): 0/20 parity (every prompt token soup) * sgl_no_mtp (PCG explicitly disabled via --enforce-disable-PCG): 20/20 parity This fix adds an explicit Gemma4 arch check so PCG is auto-disabled for any Gemma4ForCausalLM / Gemma4ForConditionalGeneration deployment unless the user explicitly opts in via SGLANG_ENABLE_PIECEWISE_CUDA_GRAPH_FOR_MM=1. Stack base: pyc/feat-gemma4-ultimate (PR #18)
This was referenced May 26, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
The single composed branch of all validated Gemma-4 optimization work from
pyc96/sglang, evaluated across 3 models (31B-it dense, 26B-A4B-IT MoE, E4B-it dense+PLE+KV-shared) x 2 MTP modes (with and without).Supersedes PR #18 with 4 additional PRs that #18 omitted (PR #4, #10, #17, #19/#20) plus a critical PCG bug fix discovered during validation.
Headline (live H100 TP=2, triton, n=80, MMLU N=500)
MMLU N=500 (seed 0, temp 0)
What's composed
Critical fix discovered during validation
When PR #10 + #16 + #17 are composed, the 31B-it no-MTP server captured piecewise CUDA graph and generated token soup (Korean/Latin garbage) for every prompt (0/20 parity).
Root cause: PR #10 (
mm_disabled_modelsfor Gemma4) makesis_multimodalevaluate to False forGemma4ForConditionalGenerationon text-only deployments. PR #16's PCG auto-disable gate keys offis_multimodal=Trueonly -- so once PR #10 fires, the PCG disable is silently bypassed, PCG captures the dense Gemma-4 forward, and the captured graph produces garbage.Fix (commit
f5c88154b): explicit Gemma-4 arch check independent ofis_multimodal. After fix: 18-19/20 parity on the same workload, MMLU restored to 0.780.Recommendation for production
MTP structural gap (documented)
_handle_frozen_kv_mtp(arg_groups/speculative_hook.py:233-250) forcesdisable_overlap_schedule=True+max_running_requests=48. vLLM peaks at 80 concurrent reqs / 5462 tok/s decode on 26B-MTP; SGLang MTP peaks at ~12. NOT a kernel issue; requires SGLang MTP worker refactor.Reproducer
All ULT v2 auto-tunes fire on launch (visible in server log):
Capping chunked_prefill_size at 4096(dense only, PR perf(gemma4 31b): cap chunked_prefill_size=4096 + bump mem_fraction floor to 0.88 (dense) #17)Bumping mem_fraction_static from 0.778 to 0.88(dense only, PR perf(gemma4 31b): cap chunked_prefill_size=4096 + bump mem_fraction floor to 0.88 (dense) #17)Auto-enabling FlashInfer AllReduce Fusion on SM90/SM10X(PR feat(gemma4 ARF): infrastructure - wrapper + Gemma3MLP skip_all_reduce + auto-enable (PR-A/2) #19)Setting swa_full_tokens_ratio to 0.15(MoE only, PR fix(gemma4): only apply swa_full_tokens_ratio=0.15 to MoE variants #8)Multimodal is disabled for gemma4(PR perf(gemma4): add Gemma4ForConditionalGeneration to mm_disabled_models #10)Files
pyc/feat-gemma4-ultimate-v2(HEADf5c88154b)agent-pod/runs/20260525_ultimate_eval/analysis/pr_matrix.mdagent-pod/runs/20260525_ultimate_eval/final_report.mdagent-pod/runs/20260525_ultimate_eval/benchmark/result_*_v2_*.jsonl(16 files)/tmp/{mmlu,parity}_*.outand the run artifact rootStatus
Draft staged on
pyc96/sglangonly. Not submitted upstream. End-to-end validated against vLLM nightly across all 3 Gemma-4 models x 2 MTP modes with MMLU + per-prompt parity + full benchmark.CI States
Latest PR Test (Base): ❌ Run #26420944501
Latest PR Test (Extra): ❌ Run #26420944491